#include "lenet.h"

// Batch size handed to CreateTrainCache() in main().
#define TRAIN_BATCH 100

// Dataset files opened by LoadDataFromFile(); the paths are relative, so they
// are resolved against the process working directory.
#define FILE_TRAIN_IMAGE        "train-images-idx3-ubyte"
#define FILE_TRAIN_LABEL        "train-labels-idx1-ubyte"
#define FILE_TEST_IMAGE        "t10k-images-idx3-ubyte"
#define FILE_TEST_LABEL        "t10k-labels-idx1-ubyte"
// Model file that main() tries to load through LoadModelFromFile().
#define LENET_FILE         "model.dat"
// Number of samples main() allocates for and reads from the train / test files.
#define COUNT_TRAIN        60000
#define COUNT_TEST        10000

// One input sample: a 28x28 grid of 8-bit pixels.
typedef uint8_t image[28][28];

// Evaluates x once and, when the result is not VK_SUCCESS, returns it from the
// enclosing function (which therefore has to return a VkResult-compatible type).
// NOTE(review): the expansion is a bare { } block rather than do { } while (0),
// so "if (c) VkResultCheck(x); else ..." does not compile. Not referenced in the
// visible part of this file.
#define VkResultCheck(x) { VkResult res = (x); if(VK_SUCCESS != res) return res; }

// Standardises one 28x28 image (zero mean, unit standard deviation) and writes
// it into plane 0 of features->layer0, shifted by PADDING in both dimensions.
//
// features: destination; only the 28x28 window starting at [PADDING][PADDING]
//           of plane 0 is written, every other cell is left untouched (the
//           callers in this file pass zero-initialised Feature objects).
// input:    source pixels, read-only.
//
// A flat image (variance zero, or marginally negative because of float
// rounding) has no meaningful scale; in that case the window is filled with
// zeros instead of the NaN/Inf values a division by zero would produce.
static inline void LoadInputToFeature(Feature* features, image input)
{
    float(*layer0)[LENGTH_FEATURE0][LENGTH_FEATURE0] = features->layer0;
    // `input` has decayed to a pointer to rows, so the row count must come from
    // the typedef while the column count can come from the row type itself.
    const int rows = (int)(sizeof(image) / sizeof(*input));
    const int cols = (int)(sizeof(*input) / sizeof(**input));
    const long sz = (long)rows * cols;

    float sum = 0, sum_sq = 0;
    for (int j = 0; j < rows; ++j)
    {
        for (int k = 0; k < cols; ++k)
        {
            sum += input[j][k];
            sum_sq += input[j][k] * input[j][k];
        }
    }

    const float mean = sum / sz;
    const float variance = sum_sq / sz - mean * mean;
    // The comparison is also false for NaN, so any degenerate value ends up as 0.
    const float std = variance > 0 ? sqrtf(variance) : 0;

    for (int j = 0; j < rows; ++j)
    {
        for (int k = 0; k < cols; ++k)
        {
            layer0[0][j + PADDING][k + PADDING] = std > 0 ? (input[j][k] - mean) / std : 0;
        }
    }
}

// Returns the index of the largest value in output[0..count-1].
//
// Ties keep the lowest index because only a strictly greater value replaces the
// current best. Returns 0 when count is 0 or 1 (nothing is read for count 0).
// The index is tracked in 32 bits so the loop terminates for any count; the
// return type is 8 bits wide, so an index above 255 is truncated on return.
static uint8_t GetMaxResult(float* output, uint32_t count)
{
    uint32_t best = 0;
    for (uint32_t i = 1; i < count; ++i)
    {
        if (output[i] > output[best])
            best = i;
    }
    return (uint8_t)best;
}

// Writes the raw bytes of *lenet to `filename`, replacing any existing file.
//
// Returns 0 on success and 1 when the file cannot be opened, the record is not
// written completely, or closing (and thereby flushing) the stream fails.
static int SaveModelToFile(LeNet5* lenet, const char filename[])
{
    FILE* fp = fopen(filename, "wb");
    if (!fp) return 1;
    int failed = fwrite(lenet, sizeof(LeNet5), 1, fp) != 1;
    // Buffered data is only pushed out on close, so a failing fclose means the
    // file on disk may be incomplete.
    if (fclose(fp) != 0) failed = 1;
    return failed;
}

// Reads one raw LeNet5 record from `filename` into *lenet.
//
// Returns 0 on success and 1 when the file cannot be opened or holds fewer than
// sizeof(LeNet5) bytes. On failure *lenet may be partially overwritten and must
// not be used as a model; main() reacts to a non-zero result by calling
// Initial().
static int LoadModelFromFile(LeNet5* lenet, const char filename[])
{
    FILE* fp = fopen(filename, "rb");
    if (!fp) return 1;
    const int failed = fread(lenet, sizeof(LeNet5), 1, fp) != 1;
    fclose(fp);
    return failed;
}

// Loads `count` images and `count` labels from a pair of dataset files.
//
// data:       receives count 28x28 byte images, read after skipping the first
//             16 bytes of data_file.
// label:      receives count label bytes, read after skipping the first 8 bytes
//             of label_file.
// count:      number of samples to read; a negative value is rejected.
// data_file,
// label_file: paths of the two files.
//
// The skipped byte counts presumably correspond to the IDX3/IDX1 headers of the
// MNIST files named at the top of this file; the header contents are not
// validated.
//
// Returns 0 when both files were read completely, non-zero when a file cannot
// be opened, a seek fails, or a file holds fewer than `count` entries. Both
// streams are closed on every path.
static int LoadDataFromFile(unsigned char(*data)[28][28], unsigned char label[], const int count, const char data_file[], const char label_file[])
{
    enum { IMAGE_HEADER_BYTES = 16, LABEL_HEADER_BYTES = 8 };

    int failed = 1;
    FILE* fp_image = fopen(data_file, "rb");
    FILE* fp_label = fopen(label_file, "rb");
    if (fp_image && fp_label && count >= 0)
    {
        const size_t wanted = (size_t)count;
        failed = fseek(fp_image, IMAGE_HEADER_BYTES, SEEK_SET) != 0
            || fseek(fp_label, LABEL_HEADER_BYTES, SEEK_SET) != 0
            || fread(data, sizeof(*data), wanted, fp_image) != wanted
            || fread(label, sizeof(*label), wanted, fp_label) != wanted;
    }
    // Either stream may be open even when the other one failed to open.
    if (fp_image) fclose(fp_image);
    if (fp_label) fclose(fp_label);
    return failed;
}

// Fills *lenet with freshly initialised parameters.
//
// Every weight is drawn from rand() mapped onto [-1, 1] and then multiplied by a
// per-layer factor sqrt(6 / (K*K*(in + out))) for the three kernel layers and
// sqrt(6 / (in + out)) for the last layer (this has the shape of the
// Xavier/Glorot uniform bound). Everything from bias0_1 to the end of the struct
// is set to zero.
//
// The function walks the struct as one flat float array and therefore relies on
// LeNet5 laying out weight0_1, weight2_3, weight4_5, weight5_6 and bias0_1 in
// that order with no gaps - NOTE(review): confirm against the declaration in
// lenet.h before reordering members.
//
// rand() is not seeded here, so the sequence depends on whatever seeding the
// caller did (none in this file).
static void Initial(LeNet5* lenet)
{
    float* const w01 = (float*)lenet->weight0_1;
    float* const w23 = (float*)lenet->weight2_3;
    float* const w45 = (float*)lenet->weight4_5;
    float* const w56 = (float*)lenet->weight5_6;
    float* const bias = (float*)lenet->bias0_1;
    float* const end = (float*)(lenet + 1);

    const float scale01 = sqrtf(6.f / (LENGTH_KERNEL * LENGTH_KERNEL * (LAYER0 + LAYER1)));
    const float scale23 = sqrtf(6.f / (LENGTH_KERNEL * LENGTH_KERNEL * (LAYER2 + LAYER3)));
    const float scale45 = sqrtf(6.f / (LENGTH_KERNEL * LENGTH_KERNEL * (LAYER4 + LAYER5)));
    const float scale56 = sqrtf(6.f / (LAYER5 + LAYER6));

    //srand((unsigned)time(0));
    for (float* pos = w01; pos < bias; ++pos)
        *pos = rand() * (2.f / RAND_MAX) - 1;

    for (float* pos = w01; pos < w23; ++pos)
        *pos *= scale01;
    for (float* pos = w23; pos < w45; ++pos)
        *pos *= scale23;
    for (float* pos = w45; pos < w56; ++pos)
        *pos *= scale45;
    for (float* pos = w56; pos < bias; ++pos)
        *pos *= scale56;

    // Store the zeros as float so the bias storage is written with the same type
    // it is later read with (no int/float type punning).
    for (float* pos = bias; pos < end; ++pos)
        *pos = 0.f;
}

// Runs Predict() on the first total_size samples and counts the hits.
//
// lenet:      device context handed to Predict().
// test_data:  images to classify; each one is normalised with
//             LoadInputToFeature() first.
// test_label: expected class for each image.
// total_size: number of samples to evaluate; nothing is evaluated when it is
//             zero or negative.
//
// Returns the number of samples whose prediction equals the label.
int testing(DeviceContext* lenet, image* test_data, uint8_t* test_label, int total_size)
{
    int correct = 0;
    // Zeroed once; LoadInputToFeature() only rewrites the image window of layer0.
    Feature feature = { 0 };
    for (int i = 0; i < total_size; ++i)
    {
        LoadInputToFeature(&feature, test_data[i]);
        const int predicted = Predict(lenet, &feature);
        if (predicted == test_label[i])
            ++correct;
    }
    return correct;
}

// Trains on the data set one full batch at a time.
//
// ctx, cache:  passed through to TrainBatch(); cache->batchSize is the number of
//              samples per batch.
// train_data:  images; each is normalised with LoadInputToFeature().
// train_label: class of each image, widened to uint32_t for TrainBatch().
// total_size:  number of samples available.
//
// Only complete batches are used, so up to batchSize - 1 trailing samples are
// skipped. Nothing is done when the batch size is zero, when total_size is not
// positive, or when there are fewer samples than one batch. A progress line is
// printed each time the completed percentage increases.
void training(DeviceContext* ctx, TrainCache* cache, image* train_data, uint8_t* train_label, int total_size)
{
    const uint32_t batch = cache->batchSize;
    // Guard the unsigned arithmetic below: a zero batch would never advance and a
    // data set smaller than one batch would make "total - batch" wrap around.
    if (batch == 0 || total_size <= 0 || (uint32_t)total_size < batch)
        return;
    const uint32_t total = (uint32_t)total_size;

    Feature* feature = (Feature*)calloc(batch, sizeof(Feature));
    uint32_t* label = (uint32_t*)calloc(batch, sizeof(uint32_t));
    if (!feature || !label)
    {
        printf("ERROR!!!\nOut of memory while allocating a training batch\n");
        free(feature);
        free(label);
        return;
    }

    uint32_t percent = 0;
    // i never exceeds total, so "total - i" cannot wrap.
    for (uint32_t i = 0; total - i >= batch; i += batch)
    {
        for (uint32_t j = 0; j < batch; ++j)
        {
            LoadInputToFeature(feature + j, train_data[i + j]);
            label[j] = train_label[i + j];
        }
        TrainBatch(ctx, cache, feature, label);

        // 64-bit intermediate keeps i * 100 from overflowing on large data sets.
        const uint32_t done = (uint32_t)((uint64_t)i * 100 / total);
        if (done > percent)
        {
            percent = done;
            printf("batchsize:%u\ttrain:%2u%%\n", (unsigned)batch, (unsigned)percent);
        }
    }
    free(feature);
    free(label);
}

// Program entry point: loads the data sets, loads or initialises the model,
// trains for one pass over the training set, evaluates on the test set and
// copies the model back from the device.
//
// Returns 0 after a complete run and 1 when memory cannot be allocated or a
// dataset file cannot be loaded. In both failure cases nothing is trained and
// all buffers are released exactly once.
int main(int argc, const char* argv[])
{
    (void)argc;
    (void)argv;

    int status = 1;
    image* train_data = (image*)calloc(COUNT_TRAIN, sizeof(image));
    uint8_t* train_label = (uint8_t*)calloc(COUNT_TRAIN, sizeof(uint8_t));
    image* test_data = (image*)calloc(COUNT_TEST, sizeof(image));
    uint8_t* test_label = (uint8_t*)calloc(COUNT_TEST, sizeof(uint8_t));
    LeNet5* lenet = (LeNet5*)malloc(sizeof(LeNet5));

    if (!train_data || !train_label || !test_data || !test_label || !lenet)
    {
        printf("ERROR!!!\nOut of memory\n");
    }
    else if (LoadDataFromFile(train_data, train_label, COUNT_TRAIN, FILE_TRAIN_IMAGE, FILE_TRAIN_LABEL)
        || LoadDataFromFile(test_data, test_label, COUNT_TEST, FILE_TEST_IMAGE, FILE_TEST_LABEL))
    {
        printf("ERROR!!!\nDataset File Not Find!Please Copy Dataset to the Floder Included the exe\n");
        // Keeps the console window open when the program was started by
        // double-click on Windows; the command does not exist on other systems.
        system("pause");
    }
    else
    {
        // Start from random parameters when no usable model file exists.
        if (LoadModelFromFile(lenet, LENET_FILE))
        {
            Initial(lenet);
        }

        // NOTE(review): the results of the Create*/LoadModel/SaveModel calls are
        // not checked because their return types are not visible from this file.
        DeviceContext devctx;
        TrainCache cache;
        CreateDeviceContext(&devctx);
        CreateTrainCache(&devctx, &cache, TRAIN_BATCH);
        LoadModel(&devctx, lenet);
        {
            time_t start = time(0);
            training(&devctx, &cache, train_data, train_label, COUNT_TRAIN);
            // time_t has no printf conversion of its own; cast to match %llu.
            printf("TRAIN TIME:%llus\n", (unsigned long long)(time(0) - start));
        }
        {
            time_t start = time(0);
            int right = testing(&devctx, test_data, test_label, COUNT_TEST);
            printf("%d/%d TEST TIME:%llus\n", right, COUNT_TEST, (unsigned long long)(time(0) - start));
        }
        SaveModel(&devctx, lenet);
        // NOTE(review): the trained model is copied back into `lenet` but is not
        // written to LENET_FILE while the following call stays disabled.
        //SaveModelToFile(lenet, LENET_FILE);
        DestroyTrainCache(&devctx, &cache);
        DestroyDeviceContext(&devctx);
        status = 0;
    }

    free(lenet);
    free(train_data);
    free(train_label);
    free(test_data);
    free(test_label);
    return status;
}

